import torch.nn as nn

from centralized_verification.models.simple_mlp import SimpleMLP


class NormSimpleMLP(SimpleMLP):
    """A ``SimpleMLP`` whose linear-layer weights are re-drawn with Xavier-normal init.

    The parent constructor builds ``self.net``. Afterwards, each ``nn.Linear``
    weight found in it is re-initialised in place with
    ``nn.init.xavier_normal_``:

    * A Linear module sitting in the final position of ``self.net`` uses the
      "linear" gain.
    * Every other Linear module uses the "relu" gain.

    Bias tensors and non-Linear modules are not touched here.
    """

    def __init__(self, obs_space, num_outputs: int):
        super().__init__(obs_space, num_outputs)

        final_position = len(self.net) - 1
        for position, module in enumerate(self.net):
            if not isinstance(module, nn.Linear):
                continue
            # NOTE(review): choosing the "relu" gain for every non-final Linear
            # presumes SimpleMLP places a ReLU-like activation after each hidden
            # Linear and ends ``self.net`` on the output Linear — confirm against
            # SimpleMLP, since its layout is not visible here.
            activation = "linear" if position == final_position else "relu"
            nn.init.xavier_normal_(
                module.weight, gain=nn.init.calculate_gain(activation)
            )
